Skip to content

Add SANA-WM camera-controlled image-to-video pipeline#13881

Open
lawrence-cj wants to merge 25 commits into
huggingface:mainfrom
lawrence-cj:feat/sana-wm-diffusers-cleanup
Open

Add SANA-WM camera-controlled image-to-video pipeline#13881
lawrence-cj wants to merge 25 commits into
huggingface:mainfrom
lawrence-cj:feat/sana-wm-diffusers-cleanup

Conversation

@lawrence-cj

@lawrence-cj lawrence-cj commented Jun 7, 2026

Copy link
Copy Markdown
Contributor

What does this PR do?

Hi @sayakpaul @dg845 , Long time no see. Hoping your are doing great. ♥️

Adds SANA-WM, the camera-controlled image-to-video world model from NVIDIA + MIT HAN Lab, as a first-class diffusers pipeline and transformer. Given a first-frame image, a text prompt, and a camera trajectory (explicit c2w poses or a WASD/IJKL action-DSL string), the pipeline generates a video whose motion follows the requested camera path. Trained natively for minute-scale generation at 704×1280.

The pipeline runs in two stages:

  1. Stage 1 — SanaWMTransformer3DModel. A 1.6B-parameter bidirectional DiT with GDN-Triton linear attention and a UCPE camera-control branch; samples with an LTX-style flow-matching Euler scheduler at per-token timesteps. The first latent frame is the conditioning anchor.
  2. Stage 2 — SanaWMLTX2Refiner (optional). A chunk-causal AR refiner that wraps diffusers' LTX2VideoTransformer3DModel + LTX2TextConnectors + Gemma-3 text encoder. Processes 3 latent frames at a time with a sliding window of [source_sink + recent_history + active_block] K/V, so per-block compute is bounded and total refinement cost is linear in video length.

Both stages decode through AutoencoderKLLTX2Video.

Layout

src/diffusers/
├── models/transformers/
│   ├── transformer_sana_wm.py          # SanaWMTransformer3DModel + blocks + helpers
│   └── transformer_sana_wm_kernels.py  # fused Triton kernels + camera math
└── pipelines/sana_wm/
    ├── __init__.py
    ├── pipeline_sana_wm.py             # SanaWMPipeline
    ├── pipeline_output.py              # SanaWMPipelineOutput
    ├── refiner.py                      # SanaWMLTX2Refiner + RefinerChunkRunner
    └── cam_utils.py                    # action DSL, intrinsics, resize+crop, Plücker/raymap

scripts/sana_wm/convert_sana_wm_to_diffusers.py
docs/source/en/api/{pipelines/sana_wm.md, models/sana_wm_transformer3d.md}

Usage

import torch
from PIL import Image
from diffusers import SanaWMPipeline
from diffusers.utils import export_to_video

pipe = SanaWMPipeline.from_pretrained(
    "Efficient-Large-Model/SANA-WM_bidirectional-diffusers",
    torch_dtype=torch.bfloat16,
)
pipe.vae.to(torch.float32)
pipe.enable_model_cpu_offload()

out = pipe(
    image=Image.open("input.png").convert("RGB"),
    prompt="A car driving across a vast desert plain at golden hour.",
    action="w-80,jw-40,w-40",                    # WASD-style action DSL
    intrinsics=[800.0, 800.0, 845.0, 464.0],      # fx, fy, cx, cy in original-image pixels
    num_frames=161,
    num_inference_steps=60,
)
export_to_video(list(out.frames), "sana_wm.mp4", fps=16)

Demo

5-second sample (30 stage-1 steps + 3-step distilled AR refiner, official asset/sana_wm/demo_0 inputs, 704×1280 @ 16 fps) :

sana_wm_5s.mp4

Smoke tests

End-to-end on 1× H100 80GB with `enable_model_cpu_offload` and the official `asset/sana_wm/demo_0.{png,txt,_pose.npy,_intrinsics.npy}`:

Duration Frames Stage-1 (30 steps) Refiner (AR, 3 blocks) Output
5s 80 1:11 5:24 / step 525 KB
10s 160 1:11 28:55 (7 blocks) 1.4 MB
20s 320 1:57 ≈ 4 min / block (14) 3.2 MB
50s 800 5:33 30:46 (34 blocks) 6.3 MB

Checkpoint conversion

scripts/sana_wm/convert_sana_wm_to_diffusers.py --src Efficient-Large-Model/SANA-WM_bidirectional --dst /local/path converts the public release into a `from_pretrained`-loadable directory (VAE, Gemma-2 tokenizer + text_encoder, transformer, scheduler, refiner subfolders, top-level `model_index.json`).

Related

Paper: https://arxiv.org/abs/2605.15178

HaoyiZhu and others added 4 commits June 1, 2026 01:28
…line

Adds the public SANA-WM bidirectional camera-controlled image-to-video
model as a first-class diffusers pipeline + transformer. Layout mirrors
``sana_video``: the model lives under ``src/diffusers/models/transformers/``
as a near-single-file (kernels split off so the ``@triton.jit`` decorators
don't drown the model body); the pipeline lives under
``src/diffusers/pipelines/sana_wm/``.

Files added:

  src/diffusers/models/transformers/
  ├── transformer_sana_wm.py         # SanaWMTransformer3DModel + blocks + helpers
  └── transformer_sana_wm_kernels.py # fused Triton kernels + camera math

  src/diffusers/pipelines/sana_wm/
  ├── __init__.py
  ├── pipeline_sana_wm.py
  ├── pipeline_output.py
  ├── refiner.py
  └── cam_utils.py

Pipeline architecture:
* Stage 1: 1600M ``SanaWMTransformer3DModel`` DiT with bidirectional
  GDN-Triton linear attention + UCPE camera-control branch, LTX-style
  flow-matching Euler scheduler with per-token timesteps.
* Stage 2: LTX-2 sink-bidirectional Euler refiner (3 distilled sigma
  steps, reuses diffusers' ``LTX2VideoTransformer3DModel`` +
  ``LTX2TextConnectors`` + Gemma-3 text encoder).
* Decode through the LTX-2 VAE (``AutoencoderKLLTX2Video``).

One-line usage:

  pipe = SanaWMPipeline.from_pretrained(
      "Efficient-Large-Model/SANA-WM_bidirectional-diffusers",
      torch_dtype=torch.bfloat16,
  ).to("cuda")
  out = pipe(image=img, prompt="...", action="w-80,jw-40,w-40",
             intrinsics=[fx, fy, cx, cy])

End-to-end smoke test (stage-1 + refiner + VAE decode) passes on H100.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…xport

transformer_sana_wm.py:
* License header switched to the "HuggingFace Team and SANA-WM Authors"
  style used by merged sana_video.
* Imports rewritten in stdlib -> third-party -> diffusers order; use
  diffusers `from ...utils import logging` instead of stdlib `logging`.
* Fix 9 `Optional[X]` annotations written as `X or None` (Python's `or`
  short-circuits and silently returns `X`).
* Fix two `assert (cond, msg)` tuple-asserts in PatchEmbedMS3D.forward
  that always pass (SyntaxWarning at import time).
* Remove duplicate `__all__` declarations (the second silently overwrote
  the first).
* Remove dead `reset_bn` (imports a nonexistent `packages.apps.utils`,
  would crash on call).
* Remove the duplicate `logger = logging.getLogger(__name__)` further
  down in the file.

transformer_sana_wm_kernels.py:
* License header normalized; collapse three duplicate triton/torch import
  blocks into one.

pipeline_sana_wm.py:
* License header normalized.
* `_decode_latents` now returns `(T, H, W, 3)` float in [0, 1], matching
  the diffusers convention used by `VideoProcessor`. Returning uint8
  silently broke `export_to_video`: it does `frame * 255` assuming float
  input, so uint8 overflows to `(-x) mod 256` and inverts colors.
* `__call__` converts to PIL/uint8 only when `output_type="pil"`.
* Intrinsics argument now accepts (4,), (F, 4), (3, 3), and (F, 3, 3)
  forms (auto-extracts fx, fy, cx, cy from a 3x3 K) and auto-trims to
  `num_frames` when a longer-than-needed trajectory is passed.
* Inline `retrieve_timesteps` with the standard `# Copied from
  diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps`
  marker, matching merged sana_video.
* Docstrings + EXAMPLE_DOC_STRING updated to reflect the new return type.

pipeline_output.py:
* Update `frames` field docstring to describe the new float [0, 1] return.

refiner.py, cam_utils.py, scripts/sana_wm/convert_sana_wm_to_diffusers.py:
* License headers normalized.

Docs:
* New `docs/source/en/api/pipelines/sana_wm.md` and
  `docs/source/en/api/models/sana_wm_transformer3d.md`, modeled on
  sana_video.md / sana_video_transformer3d.md, wired into
  `docs/source/en/_toctree.yml` under Models and Pipelines.

5s end-to-end smoke test (81 frames @ 16fps, 30 stage-1 steps + 3-step
LTX-2 refiner) passes on 1x H100 80GB with `enable_model_cpu_offload`.
Round-trip diff vs raw float frames is 2.06/255 mean (h264 lossy noise),
confirming the export_to_video fix.
…+ KV cache hooks)

The first cleanup pass only kept the legacy single-shot refiner path. That
path is what the model was *not* trained on — its docstring even says
"feeding the full sequence at once is out-of-distribution" — and its cost
is O(T^2) attention over the full latent volume, which made longer videos
unusable (~21 min per refiner step at 321 frames on an H100).

Port the chunk-causal AR mode from the upstream reference so the refiner
matches the training contract:

* `refine_latents` now defaults to `block_size=3, kv_max_frames=11`
  (the canonical AR recipe). Pass `block_size=None` to fall back to the
  legacy single-shot path.
* New `_refine_latents_ar` + `_RefinerChunkRunner` orchestrate the sliding
  window: pre-capture pre-RoPE sink K/V on `z_sana[:source_sink_frames]`
  at sigma=0, then for each `block_size`-frame chunk run a 3-step Euler
  with prefix `{sink_k_pre, sink_v, sink_pe, history_k, history_v}` and
  capture post-RoPE K/V to feed the next window. History is bounded to
  `kv_max_frames - source_sink_frames` so per-block compute is constant.
* New `_predict_x0_active_block` runs the transformer on the active block
  only (Q from active, K/V from prefix+active).
* New `_capture_block_kv` runs sigma=0 forward with a pre_rope/post_rope
  capture flag set on each `attn1`.
* New `_forward_video_only_with_rope` takes a pre-built RoPE so each block
  can use absolute frame positions in the source video.
* `_streaming_self_attention` extended with the `_kv_cache_capture`,
  `_tf_capture_kv`, `_tf_kv_prefix` hook contract that AR mode uses to
  inject and capture K/V on each block.
* New helpers: `_build_rotary_emb_for_absolute_positions`,
  `_set_kv_prefix_on_blocks`, `_clear_kv_prefix_on_blocks`,
  `_set_capture_flag_on_blocks`, `_collect_captured_kv_from_blocks`.
* `_encode_prompt` now also moves the Gemma-3 text encoder back to CPU
  after producing the embeds — otherwise it stays resident through the
  entire AR loop and gates how much GPU memory the refiner transformer
  has left.

Module-level docstring updated to document both modes; existing
single-shot path preserved verbatim.
…eemption)

The AR refiner is expensive (~3-5 min per block) and the refinement loop
ran end-to-end has no in-progress state to recover, so a SLURM preemption
mid-refinement loses all progress. With the canonical
``block_size=3, kv_max_frames=11`` setup, refining a 50s video is 34
blocks of work that has to make it through without preemption on a
backfill queue.

Add per-block atomic checkpointing:

* ``SanaWMLTX2Refiner.refine_latents(checkpoint_dir=Path)`` and
  ``_refine_latents_ar`` accept a directory. After each completed AR
  block, the AR loop writes ``checkpoint_dir/state.pt`` atomically
  (tmp + os.replace).
* The payload is ``{block_idx_done, n_blocks, sink_size, block_size,
  output_shape, output, runner_state}``. ``runner_state`` is a CPU snapshot
  of the runner's ``_sink_kv_pre``, ``_history_kv_post``,
  ``_history_frames`` and ``torch.Generator`` state.
* On entry, if ``state.pt`` exists with a compatible shape signature, the
  AR loop loads the persisted output tensor + runner state and resumes
  from ``block_idx_done + 1`` instead of recomputing from scratch.
* ``SanaWMPipeline.__call__(refiner_checkpoint_dir=...)`` plumbs the
  directory through to the refiner.

Checkpoint size: ~output_volume + sink_KV (~360MB for 50 layers) +
rolling history KV (~3-4GB at full capacity) — saved once per block,
total per-block save overhead ~10s on lustre.
@github-actions github-actions Bot added size/L PR with diff > 200 LOC documentation Improvements or additions to documentation models pipelines and removed size/L PR with diff > 200 LOC labels Jun 7, 2026
@github-actions github-actions Bot added the size/L PR with diff > 200 LOC label Jun 9, 2026
* CPU unit tests for cam_utils helpers (action DSL → c2w, intrinsics
  rescale-for-crop, resize+center-crop, snap_num_frames 8k+1 rounding).
* Public-surface registration tests (top-level diffusers symbols,
  SanaWMPipelineOutput dataclass shape, refiner signature has AR defaults
  + checkpoint_dir, pipeline __call__ accepts c2w/action/intrinsics/
  refiner_checkpoint_dir).
* @slow @require_torch_accelerator integration stub for an end-to-end I2V
  against the public checkpoint, currently @unittest.skip — wires up the
  nightly GPU path without exploding regular CI.

SanaWMTransformer3DModel has hardcoded depth/hidden_size/num_heads inside
its inner SanaMSVideoCamCtrl (not exposed through register_to_config), so
the usual PipelineTesterMixin small-config fast tests aren't applicable
without a transformer refactor (followup PR).
@github-actions github-actions Bot added the tests label Jun 9, 2026
@dg845 dg845 requested review from dg845 and yiyixuxu June 12, 2026 03:54
@dg845

dg845 commented Jun 16, 2026

Copy link
Copy Markdown
Collaborator

As a preliminary comment, would it be possible to use PyTorch ops instead of custom Triton kernels (or add pure PyTorch fallback paths) for now? We will work on supporting the custom kernels through kernels. CC @sayakpaul

@lawrence-cj

Copy link
Copy Markdown
Contributor Author

As a preliminary comment, would it be possible to use PyTorch ops instead of custom Triton kernels (or add pure PyTorch fallback paths) for now? We will work on supporting the custom kernels through kernels. CC @sayakpaul

Yes, love to do that.

…ttention

`transformer_sana_wm_kernels.py` previously did a hard `import triton`
at the top of the file. That blocked importing the SANA-WM transformer
on any environment without Triton (CPU-only, ROCm without Triton,
older Triton, etc.), even though the model has pure-PyTorch attention
classes for every `*Triton` variant.

Make Triton optional and have the dispatcher transparently fall back:

* Wrap `import triton` / `import triton.language as tl` in try/except.
  When unavailable, install a shim where `@triton.jit` is a no-op so
  the kernel function definitions still load (they just aren't compiled
  by Triton). Module-level `triton.X` / `tl.X` lookups return a
  self-shimming sentinel so signature parsing doesn't blow up either.
* Add `is_triton_available()` + `_require_triton(entry_point)`. The four
  Triton-backed entry points called by the model (`fused_qk_inv_rms`,
  `fused_bigdn_func`, `cam_prep_func`, `cam_scan_bidi_chunkwise`) now
  raise a clear RuntimeError on a Triton-less host with a hint to use
  the pure-PyTorch attention variants — but the dispatcher does this
  automatically (see below) so users shouldn't ever see it.
* Delete the leftover duplicate `import torch / triton / triton.language`
  block at line 262 (left over from the upstream port).
* Register `BidirectionalGDNUCPESinglePathLiteLA` in `ATTENTION_BLOCKS`
  so the fallback chain can find it.
* New `_resolve_attention_block(name, role)` walks the requested class's
  MRO at dispatch time. If Triton isn't usable AND the requested class
  name ends in `Triton`, route to the closest registered non-`Triton`
  ancestor (BidirectionalGDNUCPESinglePathLiteLABothTriton ->
  BidirectionalGDNUCPESinglePathLiteLA, etc.) and log a one-shot warning.
* Rewire both `SanaVideoMSCamCtrlBlock` dispatch sites to use
  `_resolve_attention_block` for the GDN+UCPE camera branch and the main
  attention branch (the `BidirectionalSoftmaxUCPESinglePathLiteLA` branch
  doesn't use Triton at all so it stays hard-coded).

Tests:
* `test_kernels_module_imports_with_triton_hidden` — reloads the kernels
  module with `sys.modules['triton'] = None` and verifies the module
  imports, `is_triton_available()` is False, and the pure-PyTorch helpers
  remain callable.
* `test_resolve_attention_block_cpu_fallback` — on a CPU-only host, the
  three `*Triton` attn types resolve to the correct non-Triton ancestor.
* `test_triton_entry_point_raises_clean_error_without_triton` — verifies
  the `_require_triton` guard yields a RuntimeError that mentions Triton.
@lawrence-cj

lawrence-cj commented Jun 16, 2026

Copy link
Copy Markdown
Contributor Author

Done in c0712d3f8 — Triton is now optional, with an automatic pure-PyTorch fallback at dispatch time. Mapping when Triton isn't usable:

Requested Falls back to
BidirectionalGDNTriton BidirectionalGDN
BidirectionalGDNUCPESinglePathLiteLATriton BidirectionalGDNUCPESinglePathLiteLA
BidirectionalGDNUCPESinglePathLiteLABothTriton BidirectionalGDNUCPESinglePathLiteLA

Triton remains the default on CUDA + Triton ≥ 3. CPU tests added under tests/pipelines/sana_wm/.

@lawrence-cj

Copy link
Copy Markdown
Contributor Author

@dg845 @yiyixuxu Gentle ping here.

lawrence-cj and others added 4 commits June 18, 2026 11:26
Three CI checks were failing on the PR:

1. `check_code_quality` (43 ruff errors): mix of unused imports / import
   sorting / E731 lambdas (auto-fixable) plus a handful of F821 dead-code
   references inherited from the upstream research codebase (`xformers.*`
   inside `if _xformers_available:` blocks, an undefined `BlockHook` type
   annotation, two `x_sa`/`mlp_out` references in a block forward whose
   live assignment was already overridden by subclasses). Ran `ruff check
   --fix --unsafe-fixes` + `ruff format`, fixed the type annotation
   manually, and added targeted `# noqa: F821` markers on the conditionally
   unreachable lines.

2. `check_torch_dependencies`: `transformer_sana_wm.py` hard-imported
   `einops`, `fla`, `timm`, `termcolor`. The minimum-deps CI environment
   doesn't have them, and diffusers' lazy loader rewrites `ModuleNotFoundError`
   as `RuntimeError` so `test_pipeline_imports` blew up. Wrapped each of
   the four optional imports in a try/except shim — `rearrange`/
   `ShortConvolution`/`DropPath`/`Attention_`/`Mlp` become placeholders
   that raise a clear `ImportError` on construction, `colored` falls back
   to plain text. Class bodies that subclass these still parse at module
   load, so `import diffusers.models.transformers.transformer_sana_wm`
   succeeds anywhere. Same treatment for the kernels file's
   `from einops import rearrange, repeat`.

3. `build_pr_documentation`: doc-builder imported `SanaWMTransformer3DModel`
   from `diffusers.models.transformers` (not the diffusers top level) and
   that subpackage's `__init__.py` was missing the entry. Added the import.
* `doc-builder style src/diffusers docs/source --max_len 119` rewraps
  docstrings in the six SANA-WM files (transformer, kernels, pipeline,
  refiner, output, cam_utils) to the repo-wide 119-column limit. No
  behaviour change — purely whitespace inside docstrings.
* `make fix-copies` regenerates `dummy_pt_objects.py` and
  `dummy_torch_and_transformers_objects.py` to add `DummyObject` stubs
  for the three new public classes (`SanaWMTransformer3DModel`,
  `SanaWMPipeline`, `SanaWMLTX2Refiner`), so `from diffusers import …`
  gives the standard "missing backend" message on installs without
  torch / transformers.

Verified: `make quality` passes (ruff check, ruff format check,
doc-builder style check_only, check_doc_toc). Test suite still
15 passed / 1 skipped.
@github-actions github-actions Bot added the utils label Jun 25, 2026
@HuggingFaceDocBuilderDev

Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@dg845

dg845 commented Jun 25, 2026

Copy link
Copy Markdown
Collaborator

Can you remove dead code (that isn't used by any existing Sana-WM checkpoint) from the PR so that it is easier to review? For example, in transformer_sana_wm.py it looks like the following might be dead code:

  • the SanaMS, SanaBlock, and SanaMSBlock modules
  • some attention classes (e.g. ChunkCausalAttention, CachedCausalAttention, etc.)
  • the PAG processors (PAGCFGIdentitySelfAttnProcessorLiteLA, SelfAttnProcessorLiteLA, etc.)
  • Layer implementations such as ChunkGLUMBConvTemp, MBConvPreGLU, MaskFinalLayer, etc.
  • a bunch of helper functions (e.g. set_grad_checkpoint, is_chunk_causal_request, etc.)

Comment thread src/diffusers/models/transformers/transformer_sana_wm.py Outdated
Per @dg845's review on `transformer_sana_wm.py:44`. The 9 call sites all
match well-known patterns that have one-liner torch equivalents:

  rearrange(s, "b d -> (b d)")              -> s.reshape(b * d)
  rearrange(x, "(b d) d2 -> b (d d2)", ...) -> x.reshape(b, d * d2)
  rearrange(R, "b t h w i j -> b t h w j i") -> R.transpose(-1, -2)
  repeat(x, "b h w c -> b t h w c", t=T)    -> x.unsqueeze(1).expand(-1, T, -1, -1, -1)
  repeat(x, "b t c -> b t h w c", h=H, w=W) -> x[:, :, None, None, :].expand(-1, -1, H, W, -1)
  repeat(x, "... -> b ...", b=B)            -> x.unsqueeze(0).expand(B, *x.shape)
  repeat(x, "H W C -> B T H W C", B, T)     -> x[None, None].expand(B, T, -1, -1, -1)
  repeat(x, "B H W C -> B T H W C", T)      -> x.unsqueeze(1).expand(-1, T, -1, -1, -1)

Each replacement is bit-identical to the einops original — verified
against a fresh `einops` install on random tensors before swapping. The
optional `from einops import ...` shim block is gone from both
`transformer_sana_wm.py` and `transformer_sana_wm_kernels.py`.
Drop 23 unused symbols (1451 lines) from `transformer_sana_wm.py` that
aren't reachable from the public SANA-WM checkpoint's
`SanaMSVideoCamCtrl` -> `SanaVideoMSCamCtrlBlock` -> GDN/UCPE attention
path. Each was verified to have zero call sites outside of its own
definition (or only within other-deleted items).

Classes:
* `SanaMS`, `SanaMSBlock` — alternative `Sana` subclass + block, not
  used by the SANA-WM checkpoint (which goes through `SanaMSVideoCamCtrl`).
* `ChunkCausalAttention`, `CachedCausalAttention`,
  `ChunkedLiteLAReLURope`, `LiteLAReLURope` — chunk-causal / cached
  attention variants and their common base; SANA-WM uses the bidi GDN
  path. `LiteLAReLURope` had only the three (now-deleted) subclasses
  referencing it.
* `PAGCFGIdentitySelfAttnProcessorLiteLA`,
  `PAGIdentitySelfAttnProcessorLiteLA`, `SelfAttnProcessorLiteLA`,
  `SelfAttnProcessorLiteLAReLURope` — PAG processors; we don't expose
  PAG in the SANA-WM pipeline.
* `ChunkGLUMBConvTemp`, `CachedGLUMBConvTemp`, `MBConvPreGLU` —
  alternative FFN/conv blocks; the checkpoint uses `GLUMBConvTemp`.
* `MaskFinalLayer`, `DecoderLayer` — alternative final layers; the
  checkpoint uses `T2IFinalLayer`.
* `LabelEmbedder`, `CaptionEmbedderDoubleBr` — alternative embedders;
  the checkpoint uses `CaptionEmbedder`.

Helpers:
* `set_grad_checkpoint`, `prepare_prompt_ar`, `resize_and_crop_tensor`,
  `generate_temporal_head_mask_mod`, `is_chunk_causal_request`,
  `get_chunk_index_from_config` — training-only or
  chunk-causal-mode helpers with no inference call sites.

Verified after the diff:
* `make quality` clean.
* CPU test suite 15 passed / 1 skipped (slow GPU integration).
* `import diffusers ; SanaWMPipeline / SanaWMTransformer3DModel /
  SanaWMLTX2Refiner` resolve identically.
@lawrence-cj

lawrence-cj commented Jun 25, 2026

Copy link
Copy Markdown
Contributor Author

Addressed both in 1b8134444:

  • einops removed (88c7c3d70): the 9 call sites all map to .reshape / .transpose(-1, -2) / .unsqueeze().expand().
  • Dead code removed (2cb863050): −1475 lines in transformer_sana_wm.py (9010 → 7559). Dropped SanaMS / SanaMSBlock, the chunk-causal attention path (ChunkCausalAttention, CachedCausalAttention, ChunkedLiteLAReLURope, LiteLAReLURope), the PAG processors, ChunkGLUMBConvTemp / CachedGLUMBConvTemp / MBConvPreGLU, MaskFinalLayer / DecoderLayer / LabelEmbedder / CaptionEmbedderDoubleBr, and the unused helpers (set_grad_checkpoint, prepare_prompt_ar, resize_and_crop_tensor, etc.).

dg845 and others added 3 commits June 25, 2026 16:32
Two PR-CI failures, both ours:

* `check_torch_dependencies`: line 33 hard-imported `transformers`, which
  isn't present in the minimum-deps environment. Move the lone
  `AutoModelForCausalLM` use site inside `initialize_gemma_params` (a
  training-only helper).
* `check_repository_consistency`: `SanaWMTransformer3DModel.forward`'s
  docstring was missing entries for `mask` and `return_dict`. Added.
@lawrence-cj

Copy link
Copy Markdown
Contributor Author

@dg845 Hi, gentle ping. Seems the PR CI test failing is not due to our code. Anything else to do?

@dg845

dg845 commented Jul 1, 2026

Copy link
Copy Markdown
Collaborator

Hi @lawrence-cj, I am reviewing the code and will try to have a full review out soon. Thanks for your patience!

Comment on lines +176 to +177
vae_scale_factor_spatial: int = 32
vae_scale_factor_temporal: int = 8

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should get these attributes from the vae component like LTX2Pipeline does rather than hardcoding them here. This would make it easier for the pipeline to support different VAEs.

# Stage-1 DiT sampling — LTX-style per-token timesteps
# ------------------------------------------------------------------

def _sample_stage1(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should inline the logic in _sample_stage1 into __call__, as this follows the design used by other pipelines. We could then define standard methods like prepare_latents, etc. to organize it.

latent_channels = first_latent.shape[1]
do_cfg = guidance_scale > 1.0

scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should use the self.scheduler component rather than creating a new scheduler here.

**cam_kwargs,
}

for t in tqdm(timesteps, disable=os.getenv("DPM_TQDM", "False") == "True"):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should use self.progress_bar here. For example, this is what LTX2Pipeline does:

with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):

I think we should also respect progress bar config changes via DiffusionPipeline.set_progress_bar_config. For example, users can use it to disable the progress bar, as some tests do:

pipe.set_progress_bar_config(disable=None)

so I think we don't need to have an explicit condition for disable here.

Comment on lines +530 to +539
if isinstance(image, (str, Path)):
image = PIL.Image.open(image).convert("RGB")

if (c2w is None) == (action is None):
raise ValueError("Provide exactly one of `c2w` or `action`.")
if action is not None:
c2w = action_string_to_c2w(action)
c2w = np.asarray(c2w, dtype=np.float32)
if c2w.ndim != 3 or c2w.shape[1:] != (4, 4):
raise ValueError(f"`c2w` must be `(F, 4, 4)`; got {c2w.shape}.")

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should move the validation checks here (and below) into a separate check_inputs method.

Comment on lines +570 to +571
cropped, src_size, resized_size, crop_offset = resize_and_center_crop(image, height, width)
intr = transform_intrinsics_for_crop(intr, src_size, resized_size, crop_offset)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be better to move the image and intrinsics pre-processing into a custom VaeImageProcessor subclass, similar to what Wan Animate does:

class WanAnimateImageProcessor(VaeImageProcessor):

For example, I think resize_and_center_crop, estimate_intrinsics_with_pi3x, transform_intrinsics_for_crop, and the image normalization code in _encode_first_frame could all potentially be refactored into a custom image processor. CC @yiyixuxu

c2w, intr, (height, width), device=device, dtype=dtype, do_cfg=guidance_scale > 1.0
)

generator = torch.Generator(device=device).manual_seed(seed)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should make generator a __call__ argument, in line with other pipelines like LTX-2:

generator: torch.Generator | list[torch.Generator] | None = None,

Comment on lines +609 to +613
return (
SanaWMPipelineOutput(frames=latents.cpu(), c2w=c2w, latent=latents.cpu())
if return_dict
else (latents.cpu(),)
)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I think we can remove the cpu() casts here since we don't normally place the output latents on CPU when output_type="latent".

Comment on lines +634 to +640
if output_type == "pil":
video_uint8 = (video.numpy() * 255.0).round().clip(0, 255).astype(np.uint8)
frames: list | np.ndarray = [PIL.Image.fromarray(f) for f in video_uint8]
elif output_type == "np":
frames = video.numpy()
else:
frames = video

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should use a VideoProcessor to post-process the generated video latents rather than re-implementing the post-processing logic here.

@@ -0,0 +1,63 @@
# SANA-WM diffusers pipeline

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should move the content here to the docs (e.g. docs/source/en/api/pipelines/sana_wm.md).

STAGE_2_DISTILLED_SIGMA_VALUES: tuple[float, ...] = (0.909375, 0.725, 0.421875, 0.0)


class SanaWMLTX2Refiner(ModelMixin, ConfigMixin):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think SanaWMLTX2Refiner makes more sense as a standalone pipeline (that is, as a DiffusionPipeline subclass rather than a ModelMixin subclass) since it wants to have components (e.g. self.transformer, self.connectors, etc.) and implements a denoising loop (e.g. in refine_latents). Can you refactor it to be a pipeline (using a scheduler to implement the denoising steps)?

@dg845 dg845 left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR! I left an initial design review for the pipeline code, am still working on reviewing the modeling code.

@lawrence-cj

Copy link
Copy Markdown
Contributor Author

Thanks for the PR! I left an initial design review for the pipeline code, am still working on reviewing the modeling code.

Thanks so much for the review. I'll keep reformatting the code as your requirements.

Per @dg845's review comments on pipeline_sana_wm.py:

* Read the VAE spatial/temporal strides from the ``vae`` component
  (``vae_spatial_compression_ratio`` / ``vae_temporal_compression_ratio``)
  instead of hardcoding 32/8, mirroring LTX2Pipeline.
* Inline the stage-1 sampling loop into ``__call__`` and factor the noise
  init into a standard ``prepare_latents`` method.
* Use ``self.scheduler`` for the flow-matching Euler steps instead of
  constructing a new scheduler per call.
* Drive the sampling loop with ``self.progress_bar`` (respects
  ``set_progress_bar_config``) instead of a bare tqdm.
* Move input validation/normalization into a ``check_inputs`` method.
* Move first-frame resize+center-crop, [-1, 1] normalization and the
  intrinsics rescale into a ``SanaWMImageProcessor(VaeImageProcessor)``
  subclass (new ``image_processor.py``).
* Add ``generator`` as a ``__call__`` argument (``seed`` kept as a
  convenience shortcut).
* Post-process decoded latents with ``VideoProcessor.postprocess_video``
  rather than a hand-rolled conversion; drop the ``.cpu()`` casts on the
  ``output_type="latent"`` path.
* Move the pipeline README into the docs (docs/.../sana_wm.md); delete
  the in-package README.
Per @dg845's review: the refiner has components (transformer, connectors,
text encoder, tokenizer) and runs a denoising loop, so it fits better as a
pipeline than a ModelMixin.

* SanaWMLTX2Refiner now subclasses DiffusionPipeline and registers its
  components via ``register_modules`` (dropping the bespoke
  ``from_pretrained`` / ``save_pretrained``); standard load/save now handle
  the ``refiner/`` subfolder.
* Add a ``FlowMatchEulerDiscreteScheduler`` component (shift=1.0) and drive
  the Euler steps through ``scheduler.step`` / ``scheduler.scale_noise``
  (single-shot and per-AR-block) instead of hand-rolled updates. Numerically
  equivalent to the previous flow-matching update.
* Rename the entry point ``refine_latents`` -> ``__call__``; add a ``device``
  arg so the parent can hand it the execution device without a bulk move.
* SanaWMPipeline: keep the refiner as an optional nested component; free the
  parent's GPU weights before running it (it manages its own sub-module
  placement) and bring the VAE back for decode.
* Conversion script emits the new refiner layout (model_index.json +
  scheduler/ + tokenizer/); test asserts the refiner is a DiffusionPipeline
  with the canonical AR ``__call__`` defaults.

Validated end-to-end on 1xH100 (stage-1 + refiner on the official demo,
coherent video output).
@lawrence-cj

Copy link
Copy Markdown
Contributor Author

Pushed f259221 + 6317ce3 addressing the latest review:

  • SanaWMPipeline cleanup: VAE strides read from the vae component; check_inputs + prepare_latents; stage-1 loop inlined into __call__ using self.scheduler and self.progress_bar; generator is now a __call__ arg; first-frame/intrinsics preprocessing moved into a SanaWMImageProcessor(VaeImageProcessor); output post-processed via VideoProcessor; README moved into the docs.
  • Refiner → standalone pipeline: SanaWMLTX2Refiner now subclasses DiffusionPipeline (components via register_modules, standard load/save) and drives its Euler steps through a FlowMatchEulerDiscreteScheduler. Kept as an optional nested component of SanaWMPipeline; validated end-to-end on 1×H100.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

documentation Improvements or additions to documentation models pipelines size/L PR with diff > 200 LOC tests utils

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants